{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST distributed training and batch transform\n",
    "\n",
    "The SageMaker Python SDK helps you train and host models in optimized, production-ready containers on SageMaker. The SDK is easy to use, modular, extensible, and compatible with TensorFlow and MXNet. This tutorial shows how to train a convolutional neural network on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) using TensorFlow distributed training, and then run inference with a batch transform job."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the environment\n",
    "\n",
    "First, we set up a few things needed for this example: a SageMaker session, the AWS region, and the IAM execution role."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "# The session wraps the boto3 session and provides the S3 upload helper used below\n",
    "sagemaker_session = sagemaker.Session()\n",
    "region = sagemaker_session.boto_session.region_name\n",
    "\n",
    "# IAM role that SageMaker assumes for the training and batch transform jobs\n",
    "role = get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the MNIST dataset\n",
    "\n",
    "We now download the MNIST dataset and prepare it for training. The prepared files are uploaded to S3 in the next step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import utils\n",
    "from tensorflow.contrib.learn.python.learn.datasets import mnist\n",
    "import tensorflow as tf\n",
    "\n",
    "data_sets = mnist.read_data_sets('data', dtype=tf.uint8, reshape=False, validation_size=5000)\n",
    "\n",
    "utils.convert_to(data_sets.train, 'train', 'data')\n",
    "utils.convert_to(data_sets.validation, 'validation', 'data')\n",
    "utils.convert_to(data_sets.test, 'test', 'data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Upload the data\n",
    "We use the `sagemaker.Session.upload_data` function to upload our datasets to an S3 location. The return value, `inputs`, identifies that location; we will use it later when we start the training job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = sagemaker_session.upload_data(path='data', key_prefix='data/DEMO-mnist')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construct a script for distributed training\n",
    "Here is the full code for the network model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "!cat 'mnist.py'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a training job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from sagemaker.tensorflow import TensorFlow\n",
    "\n",
    "mnist_estimator = TensorFlow(entry_point='mnist.py',\n",
    "                             role=role,\n",
    "                             framework_version='1.12.0',\n",
    "                             training_steps=1000, \n",
    "                             evaluation_steps=100,\n",
    "                             train_instance_count=2,\n",
    "                             train_instance_type='ml.c4.xlarge')\n",
    "\n",
    "mnist_estimator.fit(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `fit()` method creates a training job on two ml.c4.xlarge instances. The logs above show the instances training, evaluating, and incrementing the number of training steps.\n",
    "\n",
    "At the end of training, the job generates a saved model for TensorFlow Serving."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SageMaker's transformer class\n",
    "\n",
    "After training, we use our TensorFlow estimator object to create a `Transformer` by invoking the `transformer()` method. This method takes arguments that configure the batch transform job; these do not need to match the values we used for the training job. The method also creates a SageMaker Model to be used for the batch transform jobs.\n",
    "\n",
    "The `Transformer` class is responsible for running batch transform jobs, which deploy the trained model to an endpoint and send requests to it to perform inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer = mnist_estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Perform inference\n",
    "\n",
    "Now that we've trained a model, we're going to use it to perform inference with a SageMaker batch transform job. The request handling behavior of the endpoint deployed during the transform job is determined by the `mnist.py` script we looked at earlier."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run a batch transform job\n",
    "\n",
    "For our batch transform job, we're going to use input data that contains 1000 MNIST images, located in the public SageMaker sample data S3 bucket. To create the batch transform job, we simply call `transform()` on our transformer with information about the input data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_bucket_name = 'sagemaker-sample-data-{}'.format(region)\n",
    "input_file_path = 'batch-transform/mnist-1000-samples'\n",
    "\n",
    "transformer.transform('s3://{}/{}'.format(input_bucket_name, input_file_path), content_type='text/csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we wait for the batch transform job to complete. The convenience method `wait()` blocks until the job has finished, so the cell below finishes running only once the batch transform job has completed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer.wait()"
   ]
  },
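  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, a blocking call such as `wait()` amounts to polling the job status until it reaches a terminal state. The sketch below illustrates that idea only; it is not the SDK's implementation. `wait_for_job`, `get_status`, and the stubbed status sequence are hypothetical names introduced for illustration, and the set of terminal states is an assumption.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "# Assumed terminal states; anything else is treated as still running\n",
    "TERMINAL_STATES = {'Completed', 'Failed', 'Stopped'}\n",
    "\n",
    "def wait_for_job(get_status, poll_seconds=30, max_polls=100):\n",
    "    # Call get_status() repeatedly until it returns a terminal state\n",
    "    for _ in range(max_polls):\n",
    "        status = get_status()\n",
    "        if status in TERMINAL_STATES:\n",
    "            return status\n",
    "        time.sleep(poll_seconds)\n",
    "    raise RuntimeError('job did not finish within the polling budget')\n",
    "\n",
    "# Stub standing in for a real status lookup, so the sketch runs offline\n",
    "fake_statuses = iter(['InProgress', 'InProgress', 'Completed'])\n",
    "print(wait_for_job(lambda: next(fake_statuses), poll_seconds=0))\n",
    "```\n",
    "\n",
    "In this notebook there is no need to write such a loop by hand; `transformer.wait()` takes care of the waiting."
   ]
  },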
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download the results\n",
    "\n",
    "The batch transform job uploads its predictions to S3. Since we did not specify `output_path` when creating the Transformer, one was generated based on the batch transform job name:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(transformer.output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's download the first ten results from S3:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from six.moves.urllib import parse\n",
    "\n",
    "import boto3\n",
    "\n",
    "parsed_url = parse.urlparse(transformer.output_path)\n",
    "bucket_name = parsed_url.netloc\n",
    "prefix = parsed_url.path[1:]\n",
    "\n",
    "s3 = boto3.resource('s3')\n",
    "\n",
    "predictions = []\n",
    "for i in range(10):\n",
    "    file_key = '{}/data-{}.csv.out'.format(prefix, i)\n",
    "\n",
    "    output_obj = s3.Object(bucket_name, file_key)\n",
    "    output = output_obj.get()[\"Body\"].read().decode('utf-8')\n",
    "\n",
    "    predictions.extend(json.loads(output)['outputs']['classes']['int64Val'])"
   ]
  },
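  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The download loop above relies on splitting the S3 output URI into a bucket name and a key prefix. A minimal standalone sketch of that step follows; `example_path` is a hypothetical URI used only for illustration, and the `.out` suffix follows the file naming used in the loop above.\n",
    "\n",
    "```python\n",
    "from urllib.parse import urlparse\n",
    "\n",
    "# Hypothetical output path, for illustration only\n",
    "example_path = 's3://example-bucket/example-transform-job'\n",
    "\n",
    "parsed = urlparse(example_path)\n",
    "bucket = parsed.netloc            # the part between 's3://' and the next '/'\n",
    "prefix = parsed.path.lstrip('/')  # the key prefix without its leading slash\n",
    "\n",
    "# Keys of the first three output files under that prefix\n",
    "keys = ['{}/data-{}.csv.out'.format(prefix, i) for i in range(3)]\n",
    "print(bucket)\n",
    "print(keys)\n",
    "```"
   ]
  },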
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For demonstration purposes, we're also going to download the corresponding original input data so that we can see how the model did with its predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy import genfromtxt\n",
    "\n",
    "plt.rcParams['figure.figsize'] = (2,10)\n",
    "\n",
    "def show_digit(img, caption='', subplot=None):\n",
    "    # Create a fresh axes when the caller does not supply one\n",
    "    if subplot is None:\n",
    "        _, subplot = plt.subplots(1, 1)\n",
    "    imgr = img.reshape((28, 28))\n",
    "    subplot.axis('off')\n",
    "    subplot.imshow(imgr, cmap='gray')\n",
    "    subplot.set_title(caption)\n",
    "\n",
    "tmp_dir = '/tmp/data'\n",
    "if not os.path.exists(tmp_dir):\n",
    "    os.makedirs(tmp_dir)\n",
    "\n",
    "for i in range(10):\n",
    "    input_file_name = 'data-{}.csv'.format(i)\n",
    "    input_file_key = '{}/{}'.format(input_file_path, input_file_name)\n",
    "    \n",
    "    s3.Bucket(input_bucket_name).download_file(input_file_key, os.path.join(tmp_dir, input_file_name))\n",
    "    input_data = genfromtxt(os.path.join(tmp_dir, input_file_name), delimiter=',')\n",
    "\n",
    "    show_digit(input_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we can see the original labels are:\n",
    "\n",
    "```\n",
    "7, 2, 1, 0, 4, 1, 4, 9, 5, 9\n",
    "```\n",
    "\n",
    "Now let's print out the predictions to compare:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(', '.join(str(p) for p in predictions))"
   ]
  }
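  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rather than comparing the two lists by eye, the match rate can be computed directly. The sketch below uses the ten labels listed above; `example_predictions` is a hypothetical list made up for illustration (it is not output from this model), written as strings on the assumption that the parsed JSON values may arrive in that form.\n",
    "\n",
    "```python\n",
    "labels = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9]\n",
    "\n",
    "# Hypothetical predictions, for illustration only\n",
    "example_predictions = ['7', '2', '1', '0', '4', '1', '4', '9', '6', '9']\n",
    "\n",
    "# Compare each prediction with its label after converting to int\n",
    "matches = [int(p) == y for p, y in zip(example_predictions, labels)]\n",
    "accuracy = sum(matches) / float(len(labels))\n",
    "mismatched = [i for i, ok in enumerate(matches) if not ok]\n",
    "\n",
    "print('accuracy: {:.2f}'.format(accuracy))\n",
    "print('mismatched indices: {}'.format(mismatched))\n",
    "```\n",
    "\n",
    "To apply this to the real results, substitute the `predictions` list built earlier for `example_predictions`."
   ]
  }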
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_tensorflow_p36",
   "language": "python",
   "name": "conda_tensorflow_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
